#!/usr/bin/env python
"""
Define classes representing CollapseGffRecord and GmapRecord.
Define the GFF reader and writer.
"""

from collections import OrderedDict
from pbcore.io import WriterBase, ReaderBase
from pbtranscript.io.FastaRandomReader import Interval

# Contact address of the module author.
__author__ = "etseng@pacificbiosciences.com"
# Value placed in the GFF "source" column of every CollapseGffRecord.
__SOURCE_PACBIO__ = "PacBio"


# Public API of this module.
__all__ = [
    "CollapseGffReader",
    "CollapseGffWriter",
    "CollapseGffRecord",
    "GmapRecord",
]


class GffRecordBase(object):
    """Base class representing a single nine-column GFF record.

    The GFF placeholder '.' in the source, feature, score and frame columns
    is stored as None and written back as '.' by __str__.
    """

    def __init__(self, seqid, source, feature,
                 start, end, score, strand,
                 frame, attributes):
        """Build a record from the nine GFF columns.

        start and end are converted with int(); attributes may be anything
        accepted by to_attributes().

        Raises ValueError if strand is neither '+' nor '-', or if the
        attributes can not be interpreted.
        """
        self.seqid = seqid
        self.source = source if source != '.' else None
        self.feature = feature if feature != '.' else None
        self.start = int(start)
        self.end = int(end)
        self.score = score if score != '.' else None
        self.strand = strand
        if self.strand not in ("+", "-"):  # must be either + or -
            raise ValueError("strand %s must either be '+' or '-'." % strand)
        self.frame = frame if frame != '.' else None
        self.attributes = GffRecordBase.to_attributes(attributes)

    @classmethod
    def to_attributes(cls, attributes):
        """Returns attributes as OrderedDict.

        Accepted inputs:
          * None or '.'           -> empty OrderedDict
          * list of (key, value)  -> OrderedDict built from the pairs
          * OrderedDict           -> returned as is (not copied)
          * str of the form 'key "value"; key "value";'

        Raises ValueError for any other type, or for a string segment that
        does not consist of exactly one key and one value.
        """
        if attributes is None or attributes == ".":
            return OrderedDict()
        elif isinstance(attributes, list):
            return OrderedDict(attributes)
        elif isinstance(attributes, OrderedDict):
            return attributes
        elif isinstance(attributes, str):
            ret = OrderedDict()
            for attr_str in attributes.split(';'):
                # Strip BEFORE testing for emptiness so that whitespace after
                # the final ';' is not mistaken for a malformed attribute.
                attr_str = attr_str.strip()
                if not attr_str:
                    continue
                # Split on any run of whitespace so that repeated blanks
                # between key and value are tolerated.
                fs = attr_str.split()
                if len(fs) != 2:
                    raise ValueError("Can not parse GFF attribute %s" % attributes)
                ret[fs[0]] = fs[1].replace('"', '')
            return ret
        else:
            raise ValueError("Invalid attributes %s" % attributes)

    @classmethod
    def fromString(cls, line):
        """Returns a GffRecordBase object by parsing line.

        line must hold exactly nine tab-separated fields, otherwise
        ValueError is raised.
        """
        fields = line.strip().split('\t')
        if len(fields) != 9:
            raise ValueError("A GFF record %s must contain 9 fields" % line)
        # GffRecordBase is named explicitly (instead of cls) because
        # subclasses may define a constructor with a different signature.
        # The attribute column is parsed once, by the constructor.
        return GffRecordBase(seqid=fields[0], source=fields[1], feature=fields[2],
                             start=fields[3], end=fields[4],
                             score=fields[5], strand=fields[6], frame=fields[7],
                             attributes=fields[8])

    def __str__(self):
        """Returns the record as one tab-separated GFF line (no newline)."""
        attributes_str = " ".join(["%s \"%s\";" %
                                   (k, self.attributes[k]) for k in self.attributes])
        if len(attributes_str) == 0:
            attributes_str = '.'
        fields = [self.seqid, self.source, self.feature,
                  self.start, self.end, self.score,
                  self.strand, self.frame, attributes_str]
        return "\t".join([str(x) if x is not None else '.' for x in fields])

    def __eq__(self, other):
        """Two records are equal when all nine columns are equal."""
        if not isinstance(other, GffRecordBase):
            # Let Python fall back to its default handling instead of
            # failing with AttributeError on unrelated objects.
            return NotImplemented
        return (self.seqid == other.seqid and self.source == other.source and
                self.feature == other.feature and self.start == other.start and
                self.end == other.end and self.score == other.score and
                self.strand == other.strand and self.frame == other.frame and
                self.attributes == other.attributes)

    def __ne__(self, other):
        """Inverse of __eq__; Python 2 does not derive != from ==."""
        result = self.__eq__(other)
        return result if result is NotImplemented else not result


class CollapseGffRecord(GffRecordBase):
    """
    A transcript or exon GFF record as generated by collapse_isoforms.

    The source column is always the PacBio source constant, score and frame
    are always unset, and the attributes column holds gene_id followed by
    transcript_id.
    """
    TRANSCRIPT = "transcript"
    EXON = "exon"
    FEATURES = (TRANSCRIPT, EXON)

    def __init__(self, seqid, start, end, feature, strand, gene_id, transcript_id):
        """Create a record; raises ValueError unless feature is in FEATURES."""
        id_pairs = [("gene_id", gene_id), ("transcript_id", transcript_id)]
        super(CollapseGffRecord, self).__init__(
            seqid=seqid, source=__SOURCE_PACBIO__, feature=feature,
            start=start, end=end, score=None, strand=strand,
            frame=None, attributes=id_pairs)
        # Validate after the base constructor has normalised '.' to None.
        known_feature = self.feature in self.FEATURES
        if not known_feature:
            raise ValueError("Feature %s not in %s" % (self.feature, self.FEATURES))
        self.gene_id = gene_id
        self.transcript_id = transcript_id

    @classmethod
    def fromString(cls, line):
        """Parse one GFF line into a CollapseGffRecord.

        Raises ValueError when the attribute column lacks gene_id or
        transcript_id (or when the line is not a valid GFF record).
        """
        parsed = GffRecordBase.fromString(line)
        try:
            ids = dict((key, parsed.attributes[key])
                       for key in ("gene_id", "transcript_id"))
        except KeyError:
            raise ValueError("Could not parse %s as CollapseGffRecord" % line)
        return CollapseGffRecord(seqid=parsed.seqid, start=parsed.start,
                                 end=parsed.end, feature=parsed.feature,
                                 strand=parsed.strand, **ids)

    @property
    def is_transcript(self):
        """True when the feature column of this record is 'transcript'."""
        transcript_feature = CollapseGffRecord.TRANSCRIPT
        return self.feature == transcript_feature

    @property
    def is_exon(self):
        """True when the feature column of this record is 'exon'."""
        exon_feature = CollapseGffRecord.EXON
        return self.feature == exon_feature


class CollapseGffWriter(WriterBase):
    """
    A GFF file writer class.

    The '##gff-version 3' header is written as soon as the writer is
    created; records are then appended one line at a time.
    """
    def __init__(self, f):
        """Open f through WriterBase and emit the GFF version header."""
        super(CollapseGffWriter, self).__init__(f)
        self.writeHeader("##gff-version 3")

    def writeHeader(self, headerLine):
        """Write one header line (trailing whitespace is dropped).

        Raises ValueError if headerLine does not start with '##'.
        """
        if not headerLine.startswith("##"):
            raise ValueError("GFF headers must start with ##")
        self.file.write("{0}\n".format(headerLine.rstrip()))

    def writeRecord(self, record):
        """Write either a CollapseGffRecord or
           a GmapRecord including a transcript and a list of exons.

        Raises ValueError for any other kind of object rather than
        silently writing nothing.
        """
        if isinstance(record, CollapseGffRecord):
            gff_records = [record]
        elif isinstance(record, GmapRecord):
            # Transcript line first, followed by its reference exons.
            gff_records = [record.transcript_gff_record]
            gff_records.extend(record.ref_exon_gff_records)
        else:
            raise ValueError("Can not write %s, expected a CollapseGffRecord "
                             "or a GmapRecord" % type(record))
        for gff_record in gff_records:
            self.file.write("{0}\n".format(str(gff_record)))


class GmapRecord(object):
    """
    Class representing GMAP output that maps a transcript to exons.

    A record is created from a transcript CollapseGffRecord; exons are
    added afterwards with add_exon(). The attributes start/rstart and
    end/rend are computed on access from the first and last reference exon.
    """
    def __init__(self, transcript):
        """
        Record keeping for GMAP output:
        chr, coverage, identity, seqid, exons

        exons --- list of Interval: 0-based start inclusive, 0-based end, exclusive

        transcript --- a CollapseGffRecord whose feature is 'transcript'
        """
        assert isinstance(transcript, CollapseGffRecord)
        assert transcript.is_transcript

        self.transcript_gff_record = transcript

        self.chr = transcript.seqid
        self.coverage = None
        self.identity = None
        self.strand = transcript.strand
        # The transcript id without any double quotes.
        self.seqid = transcript.transcript_id.replace("\"", "")
        # ref_exons, seq_exons and scores are parallel lists, one entry per exon.
        self.ref_exons = []
        self.seq_exons = []
        self.scores = []

    @property
    def gene_id(self):
        """Returns gene id."""
        return self.transcript_gff_record.gene_id

    @property
    def transcript_id(self):
        """Returns transcript id."""
        return self.transcript_gff_record.transcript_id

    def __str__(self):
        return """
        chr: {chr}
        strand: {strand}
        coverage: {coverage}
        identity: {identity}
        seqid: {seqid}
        ref exons: {ref_exons}
        seq exons: {seq_exons}
        scores: {scores}
        """.format(chr=self.chr, strand=self.strand, coverage=self.coverage,
                   identity=self.identity, seqid=self.seqid,
                   ref_exons=self.ref_exons, seq_exons=self.seq_exons,
                   scores=self.scores)

    def __eq__(self, other):
        """Compare mapping data; the transcript GFF record is not compared."""
        if not isinstance(other, GmapRecord):
            # Unrelated objects are simply not equal instead of raising
            # AttributeError.
            return NotImplemented
        return (self.chr == other.chr and self.coverage == other.coverage and
                self.identity == other.identity and self.strand == other.strand and
                self.seqid == other.seqid and self.ref_exons == other.ref_exons and
                self.seq_exons == other.seq_exons and self.scores == other.scores)

    def __ne__(self, other):
        """Inverse of __eq__; Python 2 does not derive != from ==."""
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __getattr__(self, key):
        # Only reached for attributes that normal lookup did not find:
        # start/rstart and end/rend are derived from the reference exons.
        if key == 'rstart' or key == 'start':
            return self.get_start()
        elif key == 'rend' or key == 'end':
            return self.get_end()
        else:
            raise AttributeError(key)

    def get_start(self):
        """Returns start position of the very first exon.

        Raises ValueError if no exon has been added yet.
        """
        if len(self.ref_exons) == 0:
            raise ValueError("Could not get exon start of a transcript which has NO exon!")
        return self.ref_exons[0].start

    def get_end(self):
        """Returns end position of the very last exon.

        Raises ValueError if no exon has been added yet.
        """
        if len(self.ref_exons) == 0:
            raise ValueError("Could not get exon end of a transcript which has NO exon!")
        return self.ref_exons[-1].end

    def add_exon(self, exon):
        """Add an exon CollapseGffRecord.

        The record's start is shifted by -1 to obtain a 0-based start; the
        same interval is used for both the reference and the sequence exon.
        Exons are always appended, so they must be added in ascending order.
        """
        assert isinstance(exon, CollapseGffRecord)
        assert exon.is_exon
        self._add_exon(exon.start-1, exon.end, exon.start-1, exon.end, rstrand="+", score=None)

    def _add_exon(self, rStart0, rEnd1, sStart0, sEnd1, rstrand, score):
        """Add a new exon to ref_exons, seq_exons and scores.

        rStart0, rEnd1 --- reference interval, start < end required
        sStart0, sEnd1 --- sequence interval, start < end required
        rstrand --- '+' appends the exon, '-' inserts it at the front
        score --- stored in scores at the same position as the exon

        Raises ValueError for an empty/reversed interval or unknown strand.
        """
        if not (rStart0 < rEnd1 and sStart0 < sEnd1):
            raise ValueError("Invalid exon [%s, %s), [%s, %s]" % (rStart0, rEnd1, sStart0, sEnd1))
        if rstrand == '-':
            # New exon must lie entirely before the current first exon.
            assert len(self.ref_exons) == 0 or self.ref_exons[0].start >= rEnd1
            self.scores.insert(0, score)
            self.ref_exons.insert(0, Interval(rStart0, rEnd1))
            self.seq_exons.insert(0, Interval(sStart0, sEnd1))
        elif rstrand == "+":
            # New exon must lie entirely after the current last exon.
            assert len(self.ref_exons) == 0 or self.ref_exons[-1].end <= rStart0
            self.scores.append(score)
            self.ref_exons.append(Interval(rStart0, rEnd1))
            self.seq_exons.append(Interval(sStart0, sEnd1))
        else:
            raise ValueError("Invalid strand %s" % rstrand)

    @property
    def ref_exon_gff_records(self):
        """Return reference exons as a CollapseGffRecord list.

        Each exon start is shifted by +1 when converted back to a GFF
        record. The gene id is the part of seqid before its last '.'; when
        seqid contains no '.', the transcript's own gene id is used instead
        of cutting off the last character of seqid.
        """
        prefix, dot, _ = self.seqid.rpartition('.')
        gene_id = prefix if dot else self.gene_id
        return [CollapseGffRecord(seqid=self.chr, feature=CollapseGffRecord.EXON,
                                  start=exon.start+1, end=exon.end,
                                  strand=self.strand, gene_id=gene_id,
                                  transcript_id=self.seqid)
                for exon in self.ref_exons]


class CollapseGffReader(ReaderBase):
    """
    A GFF file reader class reads PacBio-style GFF from the collapsed output.
    which is
    0) chromosome
    1) source (PacBio)
    2) feature (transcript|exon)
    3) start (1-based)
    4) end (1-based)
    5) score (always .)
    6) strand
    7) frame (always .)
    8) blurb

    ex:
    chr1    PacBio  transcript      897326  901092  .       +       .       gene_id "PB.1"; transcript_id "PB.1.1";
    chr1    PacBio  exon    897326  897427  .       +       .       gene_id "PB.1"; transcript_id "PB.1.1";

    Iterating over the reader yields one GmapRecord per transcript line,
    holding the exon lines that follow it. Blank lines are ignored.
    """
    def __init__(self, f):
        """Open f through ReaderBase and consume the leading '##' headers."""
        super(CollapseGffReader, self).__init__(f)
        # prevLine holds a line that was read but not yet turned into a
        # record (the first record line, or the next transcript line).
        self.headers, self.prevLine = self._readHeaders()

    def _readHeaders(self):
        """Returns (list of '##' header lines, first record line or None)."""
        headers = []
        firstLine = None
        for line in self.file:
            if line.startswith("##"):
                headers.append(line.rstrip())
            elif line.strip():
                firstLine = line
                break
            # else: blank line, skip it
        return headers, firstLine

    def __iter__(self):
        return self

    def next(self):
        """Get the next GmapRecord or raise StopIteration.

        Raises ValueError if the pending line is not a transcript record or
        if any line can not be parsed as a CollapseGffRecord.
        """
        if not self.prevLine:
            raise StopIteration
        gff_record = CollapseGffRecord.fromString(self.prevLine)
        if not gff_record.is_transcript:
            raise ValueError("GmapRecord must begin with a transcript. %s" % self.prevLine)
        self.prevLine = None
        gmap_record = GmapRecord(transcript=gff_record)

        for line in self.file: # continue to read exons until reach the next transcript
            if not line.strip():
                # Tolerate blank lines, e.g. a trailing newline at end of file.
                continue
            gff_record = CollapseGffRecord.fromString(line)
            if gff_record.is_exon:
                gmap_record.add_exon(gff_record)
            else: # reach the next transcript
                self.prevLine = line
                break
        return gmap_record

    # Python 3 iterator protocol; next() stays available for Python 2 callers.
    __next__ = next
